{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "[source](../../api/alibi_detect.od.vaegmm.rst)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Variational Auto-Encoding Gaussian Mixture Model"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Overview\n",
    "\n",
    "\n",
    "The Variational Auto-Encoding Gaussian Mixture Model (VAEGMM) Outlier Detector follows the [Deep Autoencoding Gaussian Mixture Model for Unsupervised Anomaly Detection](https://openreview.net/forum?id=BJJLHbb0-) paper but uses a [VAE](https://arxiv.org/abs/1312.6114) instead of a regular Auto-Encoder. The encoder compresses the data, while the instances reconstructed by the decoder are used to create additional features based on the reconstruction error between the input and its reconstruction. These features are combined with the encodings and fed into a Gaussian Mixture Model ([GMM](https://en.wikipedia.org/wiki/Mixture_model#Gaussian_mixture_model)). The VAEGMM outlier detector is first trained on a batch of unlabeled but normal (*inlier*) data. Unsupervised or semi-supervised training is desirable since labeled data is often scarce. The sample energy of the GMM can then be used to determine whether an instance is an outlier (high sample energy) or not (low sample energy). The algorithm is suitable for tabular and image data."
   ]
  },
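  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To make the scoring step concrete, the NumPy sketch below assumes the DAGMM formulation: the density network outputs soft component memberships $\\gamma$, from which the mixture weights $\\phi_k$, means $\\mu_k$ and covariances $\\Sigma_k$ are estimated, and the sample energy is $E(z) = -\\log \\sum_k \\phi_k \\mathcal{N}(z \\mid \\mu_k, \\Sigma_k)$. The helpers `gmm_params` and `sample_energy` are hypothetical names and the inputs are random stand-ins; this illustrates the technique and is not alibi-detect's implementation.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "\n",
    "def gmm_params(z, gamma):\n",
    "    # z: (n, d) GMM input features; gamma: (n, k) soft memberships from the density net\n",
    "    n_k = gamma.sum(axis=0)                    # (k,) effective component sizes\n",
    "    phi = n_k / z.shape[0]                     # (k,) mixture weights\n",
    "    mu = (gamma.T @ z) / n_k[:, None]          # (k, d) component means\n",
    "    diff = z[:, None, :] - mu[None, :, :]      # (n, k, d)\n",
    "    cov = np.einsum('nk,nki,nkj->kij', gamma, diff, diff) / n_k[:, None, None]  # (k, d, d)\n",
    "    return phi, mu, cov\n",
    "\n",
    "\n",
    "def sample_energy(z, phi, mu, cov, eps=1e-6):\n",
    "    # E(z) = -log sum_k phi_k * N(z | mu_k, cov_k); high energy means unlikely under the GMM\n",
    "    d = z.shape[1]\n",
    "    cov = cov + eps * np.eye(d)                # small ridge keeps the covariances invertible\n",
    "    diff = z[:, None, :] - mu[None, :, :]      # (n, k, d)\n",
    "    maha = np.einsum('nki,kij,nkj->nk', diff, np.linalg.inv(cov), diff)  # squared Mahalanobis\n",
    "    _, logdet = np.linalg.slogdet(cov)         # (k,) log-determinants\n",
    "    log_pdf = -0.5 * (maha + logdet + d * np.log(2 * np.pi))  # (n, k) Gaussian log-densities\n",
    "    log_w = np.log(phi) + log_pdf\n",
    "    m = log_w.max(axis=1, keepdims=True)       # log-sum-exp for numerical stability\n",
    "    return -(m[:, 0] + np.log(np.exp(log_w - m).sum(axis=1)))\n",
    "\n",
    "\n",
    "rng = np.random.default_rng(0)\n",
    "z = rng.normal(size=(500, 3))                 # stand-in for [encoding, reconstruction features]\n",
    "gamma = rng.dirichlet(np.ones(2), size=500)   # stand-in for softmax memberships, k = 2\n",
    "phi, mu, cov = gmm_params(z, gamma)\n",
    "energy = sample_energy(z, phi, mu, cov)\n",
    "far = sample_energy(np.full((1, 3), 10.0), phi, mu, cov)\n",
    "print(far[0] > energy.max())  # True: a far-away point gets a higher energy\n",
    "```"
   ]
  },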
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Usage\n",
    "\n",
    "### Initialize\n",
    "\n",
    "Parameters:\n",
    "\n",
    "* `threshold`: threshold value for the sample energy above which the instance is flagged as an outlier.\n",
    "\n",
    "* `latent_dim`: latent dimension of the VAE.\n",
    "\n",
    "* `n_gmm`: number of components in the GMM.\n",
    "\n",
    "* `encoder_net`: `tf.keras.Sequential` instance containing the encoder network. Example:\n",
    "\n",
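    "```python\n",
    "import tensorflow as tf\n",
    "from tensorflow.keras.layers import Dense, InputLayer\n",
    "\n",
    "n_features = 10  # number of input features (illustrative value)\n",
    "latent_dim = 4   # must match `latent_dim` passed to OutlierVAEGMM\n",
    "n_gmm = 2        # must match `n_gmm` passed to OutlierVAEGMM\n",
    "\n",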
    "encoder_net = tf.keras.Sequential(\n",
    "[\n",
    "    InputLayer(input_shape=(n_features,)),\n",
    "    Dense(60, activation=tf.nn.tanh),\n",
    "    Dense(30, activation=tf.nn.tanh),\n",
    "    Dense(10, activation=tf.nn.tanh),\n",
    "    Dense(latent_dim, activation=None)\n",
    "])\n",
    "```\n",
    "\n",
    "* `decoder_net`: `tf.keras.Sequential` instance containing the decoder network. Example:\n",
    "\n",
    "```python\n",
    "decoder_net = tf.keras.Sequential(\n",
    "[\n",
    "    InputLayer(input_shape=(latent_dim,)),\n",
    "    Dense(10, activation=tf.nn.tanh),\n",
    "    Dense(30, activation=tf.nn.tanh),\n",
    "    Dense(60, activation=tf.nn.tanh),\n",
    "    Dense(n_features, activation=None)\n",
    "])\n",
    "```\n",
    "\n",
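    "* `gmm_density_net`: `tf.keras.Sequential` instance containing the GMM density network, which maps the encoding combined with the reconstruction features to soft membership probabilities over the `n_gmm` components (hence the softmax output). Example:\n",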
    "\n",
    "```python\n",
    "gmm_density_net = tf.keras.Sequential(\n",
    "[\n",
    "    InputLayer(input_shape=(latent_dim + 2,)),  # encoding + 2 reconstruction features\n",
    "    Dense(10, activation=tf.nn.tanh),\n",
    "    Dense(n_gmm, activation=tf.nn.softmax)\n",
    "])\n",
    "```\n",
    "\n",
    "* `vaegmm`: instead of using a separate encoder, decoder and GMM density net, the VAEGMM can also be passed as a `tf.keras.Model`.\n",
    "\n",
    "* `samples`: number of samples drawn for each instance during detection.\n",
    "\n",
    "* `beta`: weight on the KL-divergence loss term following the $\\beta$-[VAE](https://openreview.net/forum?id=Sy2fzU9gl) framework. Defaults to 1.\n",
    "\n",
    "* `recon_features`: function to extract features from the instances reconstructed by the decoder. Defaults to a combination of the mean squared reconstruction error and the cosine similarity between the original instances and their VAE reconstructions.\n",
    "\n",
    "* `data_type`: optionally specify the data type added to the metadata, e.g. *'tabular'* or *'image'*.\n",
    "\n",
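    "The `(latent_dim + 2,)` input shape of `gmm_density_net` follows from the feature construction: each encoding is concatenated with two per-instance reconstruction features. The NumPy sketch below assumes those two features are the mean squared reconstruction error and the cosine similarity described for `recon_features`; the helper `recon_features_sketch` and the random stand-in arrays are hypothetical and only illustrate the shapes involved.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "\n",
    "def recon_features_sketch(x, x_recon, eps=1e-12):\n",
    "    # hypothetical helper: two per-instance features built from the reconstruction\n",
    "    mse = np.mean((x - x_recon) ** 2, axis=1, keepdims=True)          # (n, 1) mean squared error\n",
    "    norms = np.linalg.norm(x, axis=1, keepdims=True) * np.linalg.norm(x_recon, axis=1, keepdims=True)\n",
    "    cos = np.sum(x * x_recon, axis=1, keepdims=True) / (norms + eps)  # (n, 1) cosine similarity\n",
    "    return np.concatenate([mse, cos], axis=1)                         # (n, 2)\n",
    "\n",
    "\n",
    "latent_dim, n_features = 4, 10\n",
    "rng = np.random.default_rng(0)\n",
    "x = rng.normal(size=(8, n_features))\n",
    "x_recon = x + 0.1 * rng.normal(size=x.shape)  # stand-in for the decoder output\n",
    "z = rng.normal(size=(8, latent_dim))          # stand-in for the encodings\n",
    "gmm_input = np.concatenate([z, recon_features_sketch(x, x_recon)], axis=1)\n",
    "print(gmm_input.shape)  # (8, 6), i.e. (batch size, latent_dim + 2)\n",
    "```\n",
    "\n",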
    "Initialized outlier detector example:\n",
    "\n",
    "```python\n",
    "from alibi_detect.od import OutlierVAEGMM\n",
    "\n",
    "od = OutlierVAEGMM(\n",
    "    threshold=7.5,\n",
    "    encoder_net=encoder_net,\n",
    "    decoder_net=decoder_net,\n",
    "    gmm_density_net=gmm_density_net,\n",
    "    latent_dim=4,\n",
    "    n_gmm=2,\n",
    "    samples=10\n",
    ")\n",
    "```"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Fit\n",
    "\n",
    "We then need to train the outlier detector. The following parameters can be specified:\n",
    "\n",
    "* `X`: training batch as a numpy array of preferably normal data.\n",
    "\n",
    "* `loss_fn`: loss function used for training. Defaults to the custom VAEGMM loss, which combines the [elbo](https://en.wikipedia.org/wiki/Evidence_lower_bound) loss, the sample energy of the GMM and a term penalizing small values on the diagonals of the GMM covariance matrices to avoid trivial solutions. It is important to balance the loss weights below so that no single loss term dominates the optimization.\n",
    "\n",
    "* `w_recon`: weight on elbo loss term. Defaults to 1e-7.\n",
    "\n",
    "* `w_energy`: weight on sample energy loss term. Defaults to 0.1.\n",
    "\n",
    "* `w_cov_diag`: weight on covariance diagonals. Defaults to 0.005.\n",
    "\n",
    "* `optimizer`: optimizer used for training. Defaults to [Adam](https://arxiv.org/abs/1412.6980) with learning rate 1e-4.\n",
    "\n",
    "* `cov_elbo`: dictionary with covariance matrix options in case the elbo loss function is used. Use either the full covariance matrix inferred from X (*dict(cov_full=None)*), only the variance (*dict(cov_diag=None)*), or a float representing the same standard deviation for each feature (e.g. *dict(sim=.05)*), which is the default.\n",
    "\n",
    "* `epochs`: number of training epochs.\n",
    "\n",
    "* `batch_size`: batch size used during training.\n",
    "\n",
    "* `verbose`: boolean whether to print training progress.\n",
    "\n",
    "* `log_metric`: additional metrics whose progress will be displayed if verbose equals True.\n",
    "\n",
    "\n",
    "```python\n",
    "od.fit(\n",
    "    X_train,\n",
    "    epochs=10,\n",
    "    batch_size=1024\n",
    ")\n",
    "```\n",
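    "\n",
    "The default loss described under `loss_fn` can be sketched as a weighted sum of three terms. This NumPy version is an illustration under assumptions rather than alibi-detect's code: the helper names are hypothetical, the reconstruction likelihood uses the default *dict(sim=.05)* option (an isotropic Gaussian with standard deviation 0.05), the KL term is scaled by `beta` and folded into the elbo term, and the covariance penalty sums the reciprocals of the covariance diagonals as in the DAGMM paper.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "\n",
    "def gaussian_nll(x, x_recon, sim=0.05):\n",
    "    # -log N(x | x_recon, sim^2 I), summed over features (the dict(sim=.05) option)\n",
    "    var = sim ** 2\n",
    "    return 0.5 * np.sum((x - x_recon) ** 2 / var + np.log(2 * np.pi * var), axis=-1)\n",
    "\n",
    "\n",
    "def kl_std_normal(z_mean, z_log_var):\n",
    "    # KL(N(z_mean, exp(z_log_var)) || N(0, I)) per instance\n",
    "    return -0.5 * np.sum(1 + z_log_var - z_mean ** 2 - np.exp(z_log_var), axis=-1)\n",
    "\n",
    "\n",
    "def cov_diag_penalty(cov):\n",
    "    # penalizes small diagonal entries of the (k, d, d) component covariances\n",
    "    return np.sum(1.0 / np.diagonal(cov, axis1=-2, axis2=-1))\n",
    "\n",
    "\n",
    "def vaegmm_loss_sketch(x, x_recon, z_mean, z_log_var, energy, cov,\n",
    "                       w_recon=1e-7, w_energy=0.1, w_cov_diag=0.005, beta=1.0, sim=0.05):\n",
    "    # weights default to the values listed above; negative elbo = NLL + beta * KL\n",
    "    neg_elbo = np.mean(gaussian_nll(x, x_recon, sim) + beta * kl_std_normal(z_mean, z_log_var))\n",
    "    return w_recon * neg_elbo + w_energy * np.mean(energy) + w_cov_diag * cov_diag_penalty(cov)\n",
    "\n",
    "\n",
    "rng = np.random.default_rng(0)\n",
    "x = rng.normal(size=(16, 10))\n",
    "loss = vaegmm_loss_sketch(\n",
    "    x, x + 0.05 * rng.normal(size=x.shape),       # stand-in reconstructions\n",
    "    rng.normal(size=(16, 4)), np.zeros((16, 4)),  # stand-in latent means and log-variances\n",
    "    energy=rng.normal(size=16),                   # stand-in sample energies\n",
    "    cov=np.stack([np.eye(6), np.eye(6)]),         # stand-in covariances, k = 2, d = 6\n",
    ")\n",
    "print(loss)\n",
    "```\n",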
    "\n",
    "It is often hard to find a good threshold value. If we have a batch of normal and outlier data and we know approximately the percentage of normal data in the batch, we can infer a suitable threshold:\n",
    "\n",
    "```python\n",
    "od.infer_threshold(\n",
    "    X,\n",
    "    threshold_perc=95\n",
    ")\n",
    "```"
   ]
  },
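  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Assuming the inferred threshold is the `threshold_perc` percentile of the instance scores on `X` (a plausible reading of the description above, not confirmed here), threshold inference reduces to a percentile computation. `infer_threshold_sketch` is a hypothetical helper and the scores are made up for illustration.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "\n",
    "def infer_threshold_sketch(scores, threshold_perc=95):\n",
    "    # threshold_perc: approximate percentage of normal instances in the batch\n",
    "    return np.percentile(scores, threshold_perc)\n",
    "\n",
    "\n",
    "scores = np.arange(100, dtype=float)  # stand-in sample energies for 100 instances\n",
    "threshold = infer_threshold_sketch(scores, threshold_perc=95)\n",
    "print(np.mean(scores > threshold))    # 0.05: the top 5% are flagged as outliers\n",
    "```"
   ]
  },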
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Detect\n",
    "\n",
    "We detect outliers by calling `predict` on a batch of instances `X`, which computes the instance-level sample energies. Setting `return_instance_score` to True additionally returns these instance-level outlier scores.\n",
    "\n",
    "The prediction takes the form of a dictionary with `meta` and `data` keys. `meta` contains the detector's metadata while `data` is also a dictionary which contains the actual predictions stored in the following keys:\n",
    "\n",
    "* `is_outlier`: boolean indicating whether each instance's score is above the threshold, in which case the instance is flagged as an outlier. The array has shape *(batch size,)*.\n",
    "\n",
    "* `instance_score`: contains instance level scores if `return_instance_score` equals True.\n",
    "\n",
    "\n",
    "```python\n",
    "preds = od.predict(\n",
    "    X,\n",
    "    return_instance_score=True\n",
    ")\n",
    "```"
   ]
  },
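  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The role of `samples` and the layout of the returned dictionary can be sketched as follows. This is a NumPy illustration under assumptions, not alibi-detect's code: it assumes each instance is scored `samples` times (the VAE samples a different latent vector on every pass) and the resulting sample energies are averaged, and it encodes the outlier flags as 0/1 integers. `predict_sketch`, `noisy_norm` and the `meta` contents are hypothetical.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "\n",
    "def predict_sketch(energy_fn, X, threshold, samples=10, return_instance_score=True):\n",
    "    # energy_fn: hypothetical callable returning one sample energy per row, standing in\n",
    "    # for a stochastic pass through the trained VAEGMM (each pass samples a new latent z)\n",
    "    X_rep = np.repeat(X, samples, axis=0)           # each instance repeated `samples` times\n",
    "    energy = energy_fn(X_rep).reshape(-1, samples)  # (batch size, samples)\n",
    "    iscore = energy.mean(axis=1)                    # average energy per instance\n",
    "    data = {'is_outlier': (iscore > threshold).astype(int), 'instance_score': None}\n",
    "    if return_instance_score:\n",
    "        data['instance_score'] = iscore\n",
    "    return {'meta': {'name': 'OutlierVAEGMM (sketch)'}, 'data': data}\n",
    "\n",
    "\n",
    "rng = np.random.default_rng(0)\n",
    "\n",
    "\n",
    "def noisy_norm(X):\n",
    "    # toy stochastic score: distance from the origin plus a little noise\n",
    "    return np.linalg.norm(X, axis=1) + rng.normal(scale=0.1, size=len(X))\n",
    "\n",
    "\n",
    "X = np.array([[0.0, 0.0], [5.0, 5.0]])\n",
    "preds = predict_sketch(noisy_norm, X, threshold=2.0)\n",
    "print(preds['data']['is_outlier'])  # [0 1]\n",
    "```"
   ]
  },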
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Examples\n",
    "\n",
    "### Tabular\n",
    "\n",
    "[Outlier detection on KDD Cup 99](../../examples/od_aegmm_kddcup.ipynb)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
